#!/usr/bin/env python
"""
Find initial mutual exclusive cliques by aligning
input reads against itself.
"""

import os.path as op
import time
import logging
import networkx as nx
from pbcore.io import FastaReader
from pbtranscript.Utils import real_upath, execute
from pbtranscript.ice_daligner import DalignerRunner
import pbtranscript.ice.pClique as pClique
from pbtranscript.ice.IceUtils import blasr_against_ref, daligner_against_ref

__author__ = 'etseng@pacificbiosciences.com'


class IceInit(object):
    """Build the initial clusters for iterative clustering and error
    correction (ICE).

    On construction the input reads are aligned against themselves
    (DALIGNER first, BLASR if DALIGNER raises RuntimeError), a graph is
    built with one edge per reported hit whose `ece_arr` is not None, and
    the graph is partitioned into mutually exclusive cliques.  The result
    is stored in `self.uc` (dict: cluster index --> list of read ids).
    """
    def __init__(self, readsFa, qver_get_func, qvmean_get_func,
                 ice_opts, sge_opts):
        """
        readsFa --- fasta filename of the reads to cluster
        qver_get_func, qvmean_get_func --- QV accessors; passed through
            unchanged to the alignment parsers
        ice_opts --- ICE options object; detect_cDNA_size(readsFa) is
            called on it before clustering
        sge_opts --- options object; within this class only
            sge_opts.blasr_nproc is read (BLASR fallback)
        """
        self.readsFa = readsFa
        self.ice_opts = ice_opts
        self.sge_opts = sge_opts

        # Must happen before clustering: _align_withDALIGNER reads
        # ice_opts.low_cDNA_size and ice_opts.sensitive_mode.
        # NOTE(review): presumably detect_cDNA_size() is what sets them.
        self.ice_opts.detect_cDNA_size(readsFa)

        self.uc = self.init_cluster_by_clique(
            readsFa=readsFa,
            qver_get_func=qver_get_func,
            qvmean_get_func=qvmean_get_func,
            ice_opts=self.ice_opts, sge_opts=self.sge_opts)

    # version using BLASR; fallback if daligner fails
    def _align_withBLASR(self, queryFa, targetFa, outFN, ice_opts, sge_opts):
        """Align queryFa against targetFa using BLASR, writing M5 to outFN.

        If outFN already exists BLASR is not run again and the existing
        file is reused as is.  Returns None.

        ice_opts --- provides maxScore and bestn
        sge_opts --- provides blasr_nproc
        """
        if op.exists(outFN):
            logging.info("{0} already exists. No need to run BLASR.".format(outFN))
            return

        # stdout and stderr of blasr are discarded; only the --out file
        # is consumed afterwards.
        cmd = "blasr {q} ".format(q=real_upath(queryFa)) + \
              "{t} ".format(t=real_upath(targetFa)) + \
              "-m 5 --maxLCPLength 15 " + \
              "--nproc {cpu} ".format(cpu=sge_opts.blasr_nproc) + \
              "--maxScore {score} ".format(score=ice_opts.maxScore) + \
              "--bestn {n} --nCandidates {n} ".format(n=ice_opts.bestn) + \
              "--out {o} ".format(o=real_upath(outFN)) + \
              "1>/dev/null 2>/dev/null"
        logging.info("Calling {cmd}".format(cmd=cmd))
        execute(cmd)

    # align with DALIGNER
    def _align_withDALIGNER(self, queryFa, output_dir):
        """Align input reads against themselves using DALIGNER.

        The run is local (use_sge=False) with a fixed cpus=4, and is
        restricted to same-strand hits.  Reads self.ice_opts.low_cDNA_size
        and self.ice_opts.sensitive_mode.

        Returns the DalignerRunner; the caller is responsible for calling
        clean_run() on it once the outputs have been consumed.
        """
        # run this locally
        runner = DalignerRunner(query_filename=queryFa, target_filename=queryFa,
                                query_converted=False, target_converted=False,
                                is_FL=True, same_strand_only=True,
                                use_sge=False, sge_opts=None,
                                cpus=4)
        runner.run(min_match_len=self.ice_opts.low_cDNA_size,
                   output_dir=output_dir,
                   sensitive_mode=self.ice_opts.sensitive_mode)
        return runner

    def _makeGraphFromM5(self, m5FN, qver_get_func, qvmean_get_func, ice_opts):
        """Construct a graph from a BLASR M5 file.

        Each record yielded by blasr_against_ref whose query and target
        ids differ and whose ece_arr is not None becomes an edge
        (qID, cID).  Returns a networkx.Graph.
        """
        alignGraph = nx.Graph()

        for r in blasr_against_ref(output_filename=m5FN,
                                   is_FL=True,
                                   sID_starts_with_c=False,
                                   qver_get_func=qver_get_func,
                                   qvmean_get_func=qvmean_get_func,
                                   ece_penalty=ice_opts.ece_penalty,
                                   ece_min_len=ice_opts.ece_min_len):
            if r.qID == r.cID:
                continue  # self hit, ignore
            if r.ece_arr is not None:
                # lazy arguments: nothing is formatted unless DEBUG is on
                logging.debug("adding edge %s,%s", r.qID, r.cID)
                alignGraph.add_edge(r.qID, r.cID)
        return alignGraph

    def _makeGraphFromLA4Ice(self, runner, qver_get_func, qvmean_get_func, ice_opts):
        """Construct a graph from the LA4Ice output files of a DALIGNER run.

        runner --- a DalignerRunner that has been run; provides
            la4ice_filenames and the query/target dazz handlers

        Every file in runner.la4ice_filenames is parsed with
        daligner_against_ref (qv_prob_threshold is fixed at .03 here).
        Each record whose query and target ids differ and whose ece_arr
        is not None becomes an edge (qID, cID).  Returns a networkx.Graph.
        """
        alignGraph = nx.Graph()

        for la4ice_filename in runner.la4ice_filenames:
            # counts add_edge calls for this file, not distinct edges
            count = 0
            start_t = time.time()
            for r in daligner_against_ref(
                    query_dazz_handler=runner.query_dazz_handler,
                    target_dazz_handler=runner.target_dazz_handler,
                    la4ice_filename=la4ice_filename,
                    is_FL=True, sID_starts_with_c=False,
                    qver_get_func=qver_get_func, qvmean_get_func=qvmean_get_func,
                    qv_prob_threshold=.03, ece_min_len=ice_opts.ece_min_len,
                    ece_penalty=ice_opts.ece_penalty,
                    same_strand_only=True, no_qv_or_aln_checking=False):
                if r.qID == r.cID:
                    continue  # self hit, ignore
                if r.ece_arr is not None:
                    alignGraph.add_edge(r.qID, r.cID)
                    count += 1
            logging.debug("total {0} edges added from {1}; took {2} sec"
                          .format(count, la4ice_filename, time.time()-start_t))
        return alignGraph

    def _findCliques(self, alignGraph, readsFa):
        """
        Find all mutually exclusive cliques within the graph, with decreased
        size.

        alignGraph - a graph, each node represent a read and each edge
        represents an alignment between two end points.
        NOTE: alignGraph is modified in place; the nodes of every clique
        found are removed from it.
        readsFa - fasta filename of all input reads; the first
        whitespace-delimited token of each record name is the read id.

        Return a dictionary of clique indices and nodes.
            key = index of a clique
            value = nodes within a clique
        Cliques are ordered by their size descendingly: index up, size down
        Reads which are not included in any cliques will be added as cliques
        of size 1.
        """
        uc = {}       # To keep cliques found
        used = set()  # nodes within any cliques; set gives O(1) membership
        ind = 0       # index of clique to discover

        # networkx 1.x returns a dict from degree(); 2.x returns a view
        # that already iterates over (node, degree) pairs.
        degrees = alignGraph.degree()
        if isinstance(degrees, dict):
            degrees = degrees.items()
        # Sort tuples of (node, degree) by degree, descendingly.  The
        # degrees are a snapshot taken before any node is removed.
        deg = sorted(degrees, key=lambda x: x[1], reverse=True)
        for d in deg:
            node = d[0]  # highest initial degree among nodes not yet visited
            if node not in alignGraph:
                continue  # already consumed by an earlier clique
            # just get the immediate neighbors since we're looking for perfect
            # cliques
            # list(...) because neighbors()/nodes() return an iterator/view
            # rather than a list on networkx >= 2.
            subGraph = alignGraph.subgraph(
                [node] + list(alignGraph.neighbors(node)))
            subNodes = list(subGraph.nodes())
            # Convert from networkx.Graph to a sparse matrix
            S, H = pClique.convert_graph_connectivity_to_sparse(
                subGraph, subNodes)
            # index of the 'node' in the sub-graph
            seed_i = subNodes.index(node)
            # Grasp a clique from subGraph, and return indices of clique nodes
            # setting gamma=0.8 means to find quasi-0.8-cliques!
            tQ = pClique.grasp(S, H, gamma=0.8, maxitr=5, given_starting_node=seed_i)
            if len(tQ) > 0:
                c = [subNodes[i] for i in tQ]  # nodes in the clique
                uc[ind] = c  # Add the clique to uc
                ind += 1
                used.update(c)  # Add clique nodes to used
                # Remove clique nodes from alignGraph and continue
                alignGraph.remove_nodes_from(c)

        # Every read that did not end up in a clique becomes a singleton.
        with FastaReader(readsFa) as reader:
            for r in reader:
                rid = r.name.split()[0]
                if rid not in used:
                    uc[ind] = [rid]
                    ind += 1
        return uc

    def init_cluster_by_clique(self, readsFa, qver_get_func, qvmean_get_func,
                               ice_opts, sge_opts):
        """
        Only called once and in the very beginning, when (probably a subset)
        of sequences are given to generate the initial cluster.

        readsFa --- initial fasta filename, probably called *_split00.fasta
        qver_get_func --- function that returns QVs on reads
        qvmean_get_func --- function that returns the mean QV on reads
        ice_opts --- options object; the attributes read here are
            ece_penalty, ece_min_len --- parameter in isoform hit calling
            bestn --- parameter in BLASR, higher helps in finding perfect
                cliques but bigger output (BLASR fallback only)
            maxScore --- parameter in BLASR, set appropriate to input
                transcript length (BLASR fallback only)
        sge_opts --- options object; blasr_nproc is read for the BLASR
            fallback only

        Self-align input (DALIGNER; BLASR if DALIGNER raises RuntimeError,
        written to readsFa + '.self.blasr') then iteratively find all
        mutually exclusive cliques (in decreasing size)
        Returns dict of cluster_index --> list of seqids
        which is the 'uc' dict that can be used by IceIterative
        """
        alignGraph = None
        try:
            runner = self._align_withDALIGNER(queryFa=readsFa,
                                              output_dir=op.dirname(readsFa))
            alignGraph = self._makeGraphFromLA4Ice(runner=runner,
                                                   qver_get_func=qver_get_func,
                                                   qvmean_get_func=qvmean_get_func,
                                                   ice_opts=ice_opts)
            runner.clean_run()
        except RuntimeError as err:  # daligner probably crashed, fall back to blasr
            # Make the fallback visible instead of switching aligner silently.
            logging.warning("DALIGNER failed (%s); falling back to BLASR.", err)
            outFN = readsFa + '.self.blasr'
            self._align_withBLASR(queryFa=readsFa, targetFa=readsFa, outFN=outFN,
                                  ice_opts=ice_opts, sge_opts=sge_opts)
            alignGraph = self._makeGraphFromM5(m5FN=outFN,
                                               qver_get_func=qver_get_func,
                                               qvmean_get_func=qvmean_get_func,
                                               ice_opts=ice_opts)

        uc = self._findCliques(alignGraph=alignGraph, readsFa=readsFa)
        return uc

